// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
// Licensed under the MIT License (MIT). See License.txt in the repo root for license information.
// ------------------------------------------------------------

#pragma once

// Allocation tag passed to _new when StoreKeysEnumerator::Create allocates an enumerator instance.
#define STORE_KEYS_ENUMERATOR_TAG 'eaSK'

namespace Data
{
    namespace TStore
    {
        //
        // Filters the keys produced by an underlying key enumerator, yielding only keys whose
        // most relevant version is present and not a deleted version. Versions are looked up
        // first in the store transaction's writeset, then in the differential component, and
        // finally in the snapshot and consolidated components.
        //
        // Runs of equal keys (per the supplied comparer) that appear back-to-back in the input
        // are emitted at most once. Only adjacent duplicates are detected, so this assumes the
        // input is ordered by that comparer (NOTE(review): confirm against the producer).
        //
        template<typename TKey, typename TValue>
        class StoreKeysEnumerator : public IEnumerator<TKey>
        {
        public:
            typedef KDelegate<ULONG(const TKey & Key)> HashFunctionType;

            //
            // Allocates a StoreKeysEnumerator (tagged STORE_KEYS_ENUMERATOR_TAG) from 'allocator'
            // and returns it through 'result'.
            //
            // Parameters:
            //   visibilitySequenceNumber - LSN used when reading the differential, consolidated and
            //       snapshot components. When equal to Constants::InvalidLsn, the LSN-less Read
            //       overloads of the differential and consolidated components are used instead
            //       (presumably "latest version" reads).
            //   snapshotComponentSPtr - may be null, in which case the snapshot is not consulted.
            //   hashFunc - passed to the transaction to obtain its writeset component.
            //
            // Returns:
            //   STATUS_INSUFFICIENT_RESOURCES if the allocation fails;
            //   the object's Status() if construction reported a failure (ownership is moved out
            //       of 'result' into a temporary and released);
            //   STATUS_SUCCESS otherwise.
            //
            static NTSTATUS Create(
                __in IComparer<TKey> & keyComparer,
                __in StoreTransaction<TKey, TValue> & storeTransaction,
                __in HashFunctionType hashFunc,
                __in DifferentialStoreComponent<TKey, TValue> & differentialStoreComponent,
                __in ConsolidationManager<TKey, TValue> & consolidationManager,
                __in LONG64 visibilitySequenceNumber,
                __in KSharedPtr<SnapshotComponent<TKey, TValue>> & snapshotComponentSPtr,
                __in IEnumerator<TKey> & enumerator,
                __in KAllocator & allocator,
                __out KSharedPtr<IEnumerator<TKey>> & result)
            {
                result = _new(STORE_KEYS_ENUMERATOR_TAG, allocator) StoreKeysEnumerator(
                    keyComparer,
                    storeTransaction, 
                    hashFunc,
                    differentialStoreComponent, 
                    consolidationManager, 
                    visibilitySequenceNumber,
                    snapshotComponentSPtr, 
                    enumerator);

                if (!result)
                {
                    return STATUS_INSUFFICIENT_RESOURCES;
                }

                if (!NT_SUCCESS(result->Status()))
                {
                    // Move the failed object out of 'result' so it is released when the temporary dies.
                    return (KSharedPtr<IEnumerator<TKey>>(Ktl::Move(result)))->Status();
                }

                return STATUS_SUCCESS;
            }

            //
            // Returns the key the enumerator is positioned on. Only meaningful after MoveNext()
            // has returned true.
            //
            TKey Current() override
            {
                return enumeratorSPtr_->Current();
            }

            //
            // Advances to the next valid, non-duplicate key from the underlying enumerator.
            // Returns true if such a key was found, false once the input is exhausted. After
            // returning false it keeps returning false without touching the underlying enumerator.
            //
            bool MoveNext() override
            {
                if (isDone_)
                {
                    return false;
                }
                
                while (enumeratorSPtr_->MoveNext())
                {
                    TKey key = enumeratorSPtr_->Current();
                    
                    // TODO:
                    // Since the output sequence should not have duplicate keys, we check for them here
                    // Ideally, there should be no duplicates in the input generated by SortedSequenceMergeEnumerator
                    // Currently, SortedSequenceMergeEnumerator may generate duplicates in some cases because the 
                    // de-duplication code is blocked by a missing API in KPriorityQueue
                    if (isPreviousSet_ && keyComparerSPtr_->Compare(key, previousKey_) == 0)
                    {
                        continue;
                    }

                    // Remember every key taken from the input, valid or not. Later duplicates of it
                    // are then skipped: a duplicate of a valid key must not be emitted again, and a
                    // duplicate of an invalid key would only repeat the same lookups.
                    previousKey_ = key;
                    isPreviousSet_ = true;

                    if (IsKeyValid(key))
                    {
                        return true;
                    }
                }

                isDone_ = true;
                return false;
            }

        private:
            //
            // Returns true if 'key' resolves to a versioned item that is not a deleted version.
            //
            // Lookup order:
            //   1. This transaction's writeset. If an entry is found there, it is final.
            //   2. The differential component, read at visibilitySequenceNumber_ unless it is
            //      Constants::InvalidLsn.
            //   3. The snapshot component (if present) and the consolidated component. If both have
            //      the key, the item with the higher version sequence number wins; the consolidated
            //      item wins on a tie.
            //
            bool IsKeyValid(TKey & key)
            {
                // Check to see if this key was already added as part of this store transaction.
                auto writeset = transactionSPtr_->GetComponent(func_);
                KInvariant(writeset != nullptr);
                auto versionedItem = writeset->Read(key);

                // Not in writeset, check in differential
                if (versionedItem == nullptr)
                {
                    if (visibilitySequenceNumber_ != Constants::InvalidLsn)
                    {
                        versionedItem = differentialSPtr_->Read(key, visibilitySequenceNumber_);
                    }
                    else
                    {
                        versionedItem = differentialSPtr_->Read(key);
                    }
                }

                // Not in differential, check in snapshot and consolidated
                if (versionedItem == nullptr)
                {
                    KSharedPtr<VersionedItem<TValue>> snapshotItem = nullptr;
                    KSharedPtr<VersionedItem<TValue>> consolidatedItem = nullptr;

                    if (snapshotSPtr_ != nullptr)
                    {
                        snapshotItem = snapshotSPtr_->Read(key, visibilitySequenceNumber_);
                    }

                    if (visibilitySequenceNumber_ != Constants::InvalidLsn)
                    {
                        consolidatedItem = consolidatedSPtr_->Read(key, visibilitySequenceNumber_);
                    }
                    else
                    {
                        consolidatedItem = consolidatedSPtr_->Read(key);
                    }

                    // Found only in consolidated
                    if (snapshotItem == nullptr && consolidatedItem != nullptr)
                    {
                        versionedItem = consolidatedItem;
                    }
                    // Found only in snapshot
                    else if (snapshotItem != nullptr && consolidatedItem == nullptr)
                    {
                        versionedItem = snapshotItem;
                    }
                    // Found in both: take the higher version sequence number (consolidated on a tie)
                    else if (snapshotItem != nullptr && consolidatedItem != nullptr)
                    {
                        versionedItem = snapshotItem->GetVersionSequenceNumber() > consolidatedItem->GetVersionSequenceNumber() ? snapshotItem : consolidatedItem;
                    }
                    // Found in neither: versionedItem stays null
                }

                // Not found or key has been deleted
                if (versionedItem == nullptr || versionedItem->GetRecordKind() == RecordKind::DeletedVersion)
                {
                    return false;
                }
                
                return true;
            }

            StoreKeysEnumerator(
                __in IComparer<TKey> & keyComparer,
                __in StoreTransaction<TKey, TValue> & storeTransaction,
                __in HashFunctionType hashFunc,
                __in DifferentialStoreComponent<TKey, TValue> & differentialStoreComponent,
                __in ConsolidationManager<TKey, TValue> & consolidationManager,
                __in LONG64 visibilitySequenceNumber,
                __in KSharedPtr<SnapshotComponent<TKey, TValue>> & snapshotComponentSPtr,
                __in IEnumerator<TKey> & enumerator);
        
            KSharedPtr<IComparer<TKey>> keyComparerSPtr_;
            KSharedPtr<StoreTransaction<TKey, TValue>> transactionSPtr_;
            KSharedPtr<DifferentialStoreComponent<TKey, TValue>> differentialSPtr_;
            KSharedPtr<ConsolidationManager<TKey, TValue>> consolidatedSPtr_;
            LONG64 visibilitySequenceNumber_;
            KSharedPtr<SnapshotComponent<TKey, TValue>> snapshotSPtr_;   // may be null
            KSharedPtr<IEnumerator<TKey>> enumeratorSPtr_;               // source of candidate keys
            HashFunctionType func_;                                      // used to fetch the writeset

            bool isDone_ = false;          // set once the underlying enumerator is exhausted
            bool isPreviousSet_ = false;   // true once previousKey_ holds a key taken from the input
            TKey previousKey_;             // last key taken from the input; used to skip adjacent duplicates
        };
        
        //
        // Takes shared references to every collaborator. The mem-initializer list follows the
        // members' declaration order, which is the order C++ actually initializes them in.
        //
        template<typename TKey, typename TValue>
        StoreKeysEnumerator<TKey, TValue>::StoreKeysEnumerator(
            __in IComparer<TKey> & keyComparer,
            __in StoreTransaction<TKey, TValue> & storeTransaction,
            __in HashFunctionType hashFunc,
            __in DifferentialStoreComponent<TKey, TValue> & differentialStoreComponent,
            __in ConsolidationManager<TKey, TValue> & consolidationManager,
            __in LONG64 visibilitySequenceNumber,
            __in KSharedPtr<SnapshotComponent<TKey, TValue>> & snapshotComponentSPtr,
            __in IEnumerator<TKey> & enumerator) :
            keyComparerSPtr_(&keyComparer),
            transactionSPtr_(&storeTransaction),
            differentialSPtr_(&differentialStoreComponent),
            consolidatedSPtr_(&consolidationManager),
            visibilitySequenceNumber_(visibilitySequenceNumber),
            snapshotSPtr_(snapshotComponentSPtr),
            enumeratorSPtr_(&enumerator),
            func_(hashFunc)
        {
        }
    }
}
